//===----------------------------------------------------------------------===//
//
// Copyright (C) 2022 Sophgo Technologies Inc.  All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//

#include "tpu_mlir/Conversion/TopToTpu/LoweringBM1684.h"
#include "tpu_mlir/Support/Dnnl/Dnnl.h"

namespace tpu_mlir {
namespace bm1684 {

void ConvLowering::LoweringF32(PatternRewriter &rewriter,
                               top::ConvOp op) const {
  // Lower top::ConvOp to the fp32 tpu conv op, forwarding every operand
  // and attribute unchanged and tagging whether a bias operand exists.
  std::vector<Value> operands(op->getOperands().begin(),
                              op->getOperands().end());
  std::vector<NamedAttribute> attrs(op->getAttrs().begin(),
                                    op->getAttrs().end());

  const bool with_bias = !module::isNone(op.getBias());
  attrs.push_back(
      rewriter.getNamedAttr("with_bias", rewriter.getBoolAttr(with_bias)));

  // A 3-element kernel shape means a 3D convolution; otherwise 2D.
  auto out_type = op.getOutput().getType();
  if (op.getKernelShape().size() == 3) {
    rewriter.replaceOpWithNewOp<tpu::Conv3DOp>(op, out_type, operands, attrs);
  } else {
    rewriter.replaceOpWithNewOp<tpu::Conv2DOp>(op, out_type, operands, attrs);
  }
}

void ConvLowering::LoweringINT8(PatternRewriter &rewriter, top::ConvOp op,
                                bool asymmetric) const {
  // Lower top::ConvOp to a per-tensor INT8 tpu conv for BM1684:
  //   * filter  -> int8 (shift-only requant, multiplier fixed to 1)
  //   * bias    -> int16
  //   * optional Winograd F(2,3) variant with its own shift and a second
  //     filter/bias copy appended to the plain ones.
  // Falls back to the fp32 lowering whenever the quantized form cannot be
  // produced or encoded.
  if (module::isWeight(op.getFilter()) == false) {
    // Dynamic (non-constant) filter cannot be quantized offline.
    LoweringF32(rewriter, op);
    return;
  }

  std::vector<Value> operands;
  operands.push_back(op.getInput());
  std::vector<NamedAttribute> attrs;
  auto attr = op.parseParam();
  auto filterOp = cast<top::WeightOp>(op.getFilter().getDefiningOp());
  auto filter_f32 = filterOp.read<float>();
  int filter_elem_num = module::getNumElements(filterOp);

  double in_scale, out_scale;
  int64_t in_zp, out_zp;
  module::getScaleAndZeroPoint(op.getInput(), in_scale, in_zp, asymmetric);
  module::getScaleAndZeroPoint(op.getOutput(), out_scale, out_zp, asymmetric);

  auto filter_max = findMaxabs(filter_f32->data(), filter_f32->size());
  int rshift = calRightShiftNum(filter_max, in_scale, out_scale, BITS_INT8);

  // A negative right shift or a stride above 15 cannot be represented by
  // the BM1684 conv descriptor; keep those cases in fp32.
  if (rshift < 0 || attr.sw > 15 || attr.sh > 15) {
    LoweringF32(rewriter, op);
    return;
  }

  // Winograd F(2,3) is only applied to 3x3 / stride-1 / dilation-1 convs
  // with a large enough output, excluding pure depthwise convs and grouped
  // convs whose per-group oc is not a multiple of NPU_NUM.
  auto groups = attr.groups;
  bool use_wino = attr.kh == 3 && attr.kw == 3 && attr.dh == 1 &&
                  attr.dw == 1 && attr.sh == 1 && attr.sw == 1 &&
                  attr.ow * attr.oh >= 100 &&
                  !(groups == attr.ic && groups == attr.oc && groups > 1) &&
                  (groups > 1 ? (attr.oc / groups) % 64 /*NPU_NUM*/ == 0 : 1);
  use_wino = use_wino && op.getDoWinograd().value_or(false);

  // Lower bias to int16. When winograd is used the buffer is doubled so
  // the winograd-scaled copy can be written into the second half later.
  std::shared_ptr<std::vector<int16_t>> bias_int16;
  if (attr.has_bias) {
    auto biasOp = cast<top::WeightOp>(op.getBias().getDefiningOp());
    auto bias_fp32 = biasOp.read<float>();
    int bias_len = bias_fp32->size();
    bias_int16 = std::make_shared<std::vector<int16_t>>(use_wino ? bias_len * 2
                                                                 : bias_len);

    float bias_scale = 1.0 * (1 << rshift) / out_scale;
    float overflow_ratio = quantizeToInt16(
        bias_fp32->data(), bias_int16->data(), bias_len, bias_scale);

    // Reduce the shift until at most 3% of the bias values saturate int16.
    while (overflow_ratio > 0.03 && rshift > 0) {
      rshift--;
      bias_scale = 1.0 * (1 << rshift) / out_scale;
      overflow_ratio = quantizeToInt16(bias_fp32->data(), bias_int16->data(),
                                       bias_len, bias_scale);
    }
  }

  // Lower weight to int8 using the (possibly reduced) right shift.
  std::vector<int64_t> rshift_v;
  rshift_v.push_back(rshift);
  std::vector<int64_t> multiplier_v;
  multiplier_v.push_back(1);
  float scale = 1.0 * (1 << rshift) * in_scale / out_scale;
  auto filter_int8 = std::make_shared<std::vector<int8_t>>(filter_f32->size());
  quantizeToInt8(filter_f32->data(), filter_int8->data(), filter_f32->size(),
                 scale);

  std::vector<int64_t> filter_shape = module::getShape(filterOp);

  /**winograd optimization*/
  if (use_wino) {
    /**
     * winograd quantize to int8
     */
    int num_connection = filter_elem_num / 9;

    /**
     * 1684 winograd backend op only support F(2,3), so only icxocx3x3 filter
     can join the winograd optimization.
     * which means 3x3 filter will be convert in to 4x4 matrix by batched G @
     weight,
     * where G = [[1, 0, 0],
                  [0.5, 0.5, 0.5],
                  [0.5, -0.5, 0.5],
                  [0, 0, 1]],
     *  which is a const generated by F(2,3)
     * dnnl MatMul do not support batched matmul, so bmm(G, weight) should be
     implemted by flatten kw,kh and ic,oc,
     * make kernel shape as (ic*oc, 3*3)
     *
     * then the matmul format is
     *            M     K     K  N
     *        [(ic*oc), 9] @ [9, 16] -> [(ic*oc), 16]
     *    (kernel              C.T
     *      .reshape(-1, 9))
     *
     * equals [16, 9] @ [9, (ic*oc)] -> [(ic*oc), 16]
     *          C       kernel.reshape(-1, 9)
     *                        .transpose(0, 1)
     *
     * equals     G  @  kernel @ G.T -> gt
    */
    // Transformed 4x4 tiles, one per (ic, oc) connection. A std::vector
    // replaces the previous raw `new float[]`, which was never freed.
    std::vector<float> fY(16 * num_connection);
    // C = kron-expanded (G (x) G) so the batched G @ w @ G.T collapses to
    // one flat [num_connection, 9] x [9, 16] matmul.
    float C[16][9] = {
        {1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00},
        {0.50, 0.50, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00},
        {0.50, -0.50, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00},
        {0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00},
        {0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00},
        {0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25},
        {0.25, -0.25, 0.25, 0.25, -0.25, 0.25, 0.25, -0.25, 0.25},
        {0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50},
        {0.50, 0.00, 0.00, -0.50, 0.00, 0.00, 0.50, 0.00, 0.00},
        {0.25, 0.25, 0.25, -0.25, -0.25, -0.25, 0.25, 0.25, 0.25},
        {0.25, -0.25, 0.25, -0.25, 0.25, -0.25, 0.25, -0.25, 0.25},
        {0.00, 0.00, 0.50, 0.00, 0.00, -0.50, 0.00, 0.00, 0.50},
        {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00},
        {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.50, 0.50},
        {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, -0.50, 0.50},
        {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00}};
    // Automatic storage instead of raw new/delete: the MatMul is released
    // even if setup()/run() throws.
    MatMul matmul;
    matmul.setup(
        filter_f32->data(), (float *)C, (float *)0 /*bias*/,
        fY.data() /* output*/, 1 /*batch*/, 1 /*batch_low*/,
        num_connection /*M*/, 9 /*K*/, 16 /*N*/, false /*relu*/,
        -1 /*relu_limits*/, 0, 0, true /*right-transpose*/,
        false /*left-transpose*/, false /*output_transpose*/, false /*h=dim*/);
    matmul.run();

    float fmaxTmp = findMaxabs(fY.data(), num_connection * 16);
    // NOTE(review): the (out_scale, in_scale) argument order here differs
    // from the non-winograd call above — confirm this is intentional.
    int winoRightShiftBits =
        calRightShiftNum(fmaxTmp, out_scale, in_scale, 8);

    if (attr.has_bias) {
      auto biasOp = cast<top::WeightOp>(op.getBias().getDefiningOp());
      auto bias_fp32 = biasOp.read<float>();
      int bias_len = bias_fp32->size();

      // Quantize the winograd bias into the second half of bias_int16,
      // reducing the shift the same way as for the plain bias.
      float bias_scale = 1.0 * (1 << winoRightShiftBits) / out_scale;
      float overflow_ratio =
          quantizeToInt16(bias_fp32->data(), bias_int16->data() + bias_len,
                          bias_len, bias_scale);

      while (overflow_ratio > 0.03 && winoRightShiftBits > 0) {
        winoRightShiftBits--;
        bias_scale = 1.0 * (1 << winoRightShiftBits) / out_scale;
        overflow_ratio =
            quantizeToInt16(bias_fp32->data(), bias_int16->data() + bias_len,
                            bias_len, bias_scale);
      }
      // NOTE(review): the winograd shift is only recorded when a bias
      // exists — confirm bias-free winograd convs are handled downstream.
      rshift_v.push_back(winoRightShiftBits);
    }

    float winoScale = (1 << winoRightShiftBits) * in_scale / out_scale;
    std::vector<int8_t> wino_filter_int8(num_connection * 16);

    // Quantize the transformed weights: round-half-up, clamp to int8.
    for (int ii = 0; ii < num_connection * 16; ii++) {
      int tmp = (float)floor(fY[ii] * winoScale + 0.5);
      tmp = (tmp > 127) ? 127 : ((tmp < -128) ? -128 : tmp);
      wino_filter_int8[ii] = (int8_t)tmp;
    }

    // The winograd filter is appended after the plain one; declaring a
    // 5x5 kernel makes the combined 9 + 16 bytes per connection fit the
    // declared tensor shape.
    filter_shape[2] = 5;
    filter_shape[3] = 5;
    std::copy(wino_filter_int8.begin(), wino_filter_int8.end(),
              std::back_inserter(*filter_int8));
  }

  auto new_type = RankedTensorType::get(filter_shape, rewriter.getI8Type());
  auto new_filter =
      top::WeightOp::create(op, "filter_int8", *filter_int8, new_type);
  operands.push_back(new_filter);

  // Replace the bias operand with the int16 weight when present; a missing
  // bias keeps the original NoneOp value.
  Value new_bias = op.getBias();
  if (attr.has_bias) {
    std::vector<int64_t> bias_shape = module::getShape(new_bias);
    if (use_wino) {
      bias_shape[0] *= 2;
    }
    auto new_type =
        RankedTensorType::get(bias_shape, rewriter.getIntegerType(16));
    new_bias = top::WeightOp::create(op, "bias_int16", *bias_int16, new_type);
  }
  operands.push_back(new_bias);
  for (auto &attr : op->getAttrs()) {
    attrs.push_back(attr);
  }
  attrs.push_back(rewriter.getNamedAttr(
      "rshift", rewriter.getI64ArrayAttr(ArrayRef<int64_t>{rshift_v})));
  attrs.push_back(rewriter.getNamedAttr(
      "quant_mode",
      tpu::RequantModeAttr::get(getContext(), tpu::RequantMode::OnlyShift)));
  attrs.push_back(
      rewriter.getNamedAttr("with_bias", rewriter.getBoolAttr(attr.has_bias)));
  auto newType = getQuantInt8Type(op.getOutput());
  if (op.getKernelShape().size() == 3) {
    rewriter.replaceOpWithNewOp<tpu::Conv3DOp>(op, newType, operands, attrs);
  } else {
    auto new_op = rewriter.replaceOpWithNewOp<tpu::Conv2DOp>(op, newType,
                                                             operands, attrs);
    if (use_wino) {
      new_op.setUseWinograd(2);
    }
  }
}

} // namespace bm1684
} // namespace tpu_mlir
